package com.chatplus.application.controller.api;

import cn.dev33.satoken.annotation.SaIgnore;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.RandomUtil;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.chatplus.application.aiprocessor.platform.image.ImgAiProcessorService;
import com.chatplus.application.aiprocessor.provider.ImgAiProcessorServiceProvider;
import com.chatplus.application.common.annotation.RateLimiter;
import com.chatplus.application.common.enumeration.LimitType;
import com.chatplus.application.common.exception.BadRequestException;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.common.page.PageParam;
import com.chatplus.application.common.util.PlusJsonUtils;
import com.chatplus.application.domain.entity.draw.SdJobEntity;
import com.chatplus.application.domain.request.ImagePublishRequest;
import com.chatplus.application.domain.request.SdLocalImageRequest;
import com.chatplus.application.domain.request.SdStabilityImageRequest;
import com.chatplus.application.domain.response.PlusPageResponse;
import com.chatplus.application.domain.response.SdJobDetailApiResponse;
import com.chatplus.application.enumeration.AiPlatformEnum;
import com.chatplus.application.service.draw.SdJobService;
import com.chatplus.application.util.BaiduTransUtil;
import com.chatplus.application.web.basecontroller.BaseController;
import com.google.common.collect.Lists;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import org.apache.commons.lang3.StringUtils;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;

import java.util.Collections;
import java.util.List;

/**
 * REST endpoints for SD (Stable Diffusion) drawing jobs: listing a user's jobs, browsing the
 * public image wall, submitting new drawing tasks, and removing or publishing existing jobs.
 *
 * <p>All endpoints except {@code /imgWall} act on behalf of the user returned by
 * {@link #getUserId()}.
 *
 * @author Angus
 */
@Validated
@RestController
@RequestMapping("/api/sd")
@Tag(name = "产品API", description = "SD API")
public class SdJobApiController extends BaseController {

    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(SdJobApiController.class);

    private final SdJobService sdJobService;

    private final ImgAiProcessorServiceProvider imgAiProcessorServiceProvider;

    public SdJobApiController(SdJobService sdJobService,
                              ImgAiProcessorServiceProvider imgAiProcessorServiceProvider) {
        this.sdJobService = sdJobService;
        this.imgAiProcessorServiceProvider = imgAiProcessorServiceProvider;
    }

    /**
     * Gets one page of the current user's task list / creation records.
     *
     * @param page     page number to fetch
     * @param pageSize number of records per page
     * @return the jobs on the requested page, or an empty list when there are none
     */
    @GetMapping("/jobs")
    @Operation(summary = "获取任务列表或者创作记录")
    public List<SdJobDetailApiResponse> jobs(@RequestParam(value = "page") Integer page,
                                             @RequestParam(value = "page_size") Integer pageSize) {
        PageParam pageParam = new PageParam(page, pageSize);
        return toResponses(sdJobService.getSdJobPage(pageParam, true, getUserId(), null));
    }

    /**
     * Gets the current user's running jobs in a single page.
     *
     * @return the running jobs, or an empty list when there are none
     */
    @GetMapping("/runningJobs")
    @Operation(summary = "获取运行中的任务记录")
    public List<SdJobDetailApiResponse> runningJobs() {
        // Size the single page from the current running-job count plus a margin of 10,
        // presumably so jobs started between the count and the query are still included.
        int pageSize = Math.toIntExact(sdJobService.getRunningJobCount(getUserId())) + 10;
        PageParam pageParam = new PageParam(1, pageSize);
        return toResponses(sdJobService.getSdJobPage(pageParam, false, getUserId(), null));
    }

    /**
     * Gets one page of the current user's finished jobs together with paging metadata.
     *
     * @param page     page number to fetch
     * @param pageSize number of records per page
     * @return a page response holding the jobs plus total page / record counts; when the page is
     *         empty both counts are reported as 0
     */
    @GetMapping("/finishJobs")
    @Operation(summary = "获取已完成的记录")
    public PlusPageResponse<SdJobDetailApiResponse> finishJobs(@RequestParam(value = "page") Integer page,
                                                               @RequestParam(value = "page_size") Integer pageSize) {
        PageParam pageParam = new PageParam(page, pageSize);
        Page<SdJobEntity> pageDTO = sdJobService.getSdJobPage(pageParam, true, getUserId(), null);
        if (pageDTO == null || CollUtil.isEmpty(pageDTO.getRecords())) {
            return new PlusPageResponse<>(page, pageSize, 0, 0, Lists.newArrayList());
        }
        List<SdJobDetailApiResponse> responseList = toResponses(pageDTO);
        return new PlusPageResponse<>(pageParam.getCurrent(), pageParam.getSize()
                , pageDTO.getPages(), pageDTO.getTotal(), responseList);
    }

    /**
     * Gets one page of the image wall. The endpoint is exempt from login ({@code @SaIgnore}) and
     * the query is not restricted to a single user.
     *
     * @param page     page number to fetch; defaults to 0 when omitted
     * @param pageSize number of records per page; defaults to 0 when omitted
     *                 (NOTE(review): how PageParam treats 0 is not visible here - confirm)
     * @return the jobs on the requested page, or an empty list when there are none
     */
    @GetMapping("/imgWall")
    @Operation(summary = "获取任务列表或者创作记录")
    @SaIgnore
    public List<SdJobDetailApiResponse> imgWall(@RequestParam(value = "page", defaultValue = "0", required = false) Integer page,
                                                @RequestParam(value = "page_size", defaultValue = "0", required = false) Integer pageSize) {
        PageParam pageParam = new PageParam(page, pageSize);
        // Fetch successful jobs; no user filter is applied (userId is null).
        return toResponses(sdJobService.getSdJobPage(pageParam, true, null, true));
    }

    /**
     * Submits a drawing task to the {@link AiPlatformEnum#SD} processor.
     *
     * <p>The prompt is passed through {@link BaiduTransUtil} with target language {@code "en"},
     * a random task id is generated, a job record is saved, and the request is then handed to
     * the processor.
     *
     * @param request the drawing request; its prompt, task id, job id and user id are overwritten
     * @throws BadRequestException when the processor's user verification returns an error message
     */
    @PostMapping("/image")
    @Operation(summary = "画图任务")
    @RateLimiter(count = 1, time = 5, limitType = LimitType.IP)
    public void image(@RequestBody SdLocalImageRequest request) {
        ImgAiProcessorService sdImageProcessor = imgAiProcessorServiceProvider.getAiProcessorService(AiPlatformEnum.SD);
        String errMsg = sdImageProcessor.verifyUserRequest(getUserId());
        if (StringUtils.isNotEmpty(errMsg)) {
            throw new BadRequestException(errMsg);
        }
        request.setPrompt(BaiduTransUtil.getTransResult(request.getPrompt(), null, "en"));
        request.setTaskId(String.format("task(%s)", RandomUtil.randomString(15)));
        initSdJobEntity(request);
        sdImageProcessor.process(request);
    }

    /**
     * Submits a drawing task to the {@link AiPlatformEnum#SD_STABILITY} processor.
     *
     * <p>The prompt is passed through {@link BaiduTransUtil} with target language {@code "en"},
     * a job record is saved, and the request is then handed to the processor.
     *
     * @param request the validated drawing request; its prompt, job id and user id are overwritten
     * @throws BadRequestException when the processor's user verification returns an error message
     */
    @PostMapping("/imageForStability")
    @Operation(summary = "画图任务")
    @RateLimiter(count = 1, time = 5, limitType = LimitType.IP)
    public void imageForStability(@RequestBody @Valid SdStabilityImageRequest request) {
        ImgAiProcessorService sdImageProcessor = imgAiProcessorServiceProvider.getAiProcessorService(AiPlatformEnum.SD_STABILITY);
        String errMsg = sdImageProcessor.verifyUserRequest(getUserId());
        if (StringUtils.isNotEmpty(errMsg)) {
            throw new BadRequestException(errMsg);
        }
        request.setPrompt(BaiduTransUtil.getTransResult(request.getPrompt(), null, "en"));
        initSdStabilityJobEntity(request);
        sdImageProcessor.process(request);
    }

    /**
     * Deletes one of the current user's jobs.
     *
     * <p>The delete is constrained by both job id and owner, so a request naming another user's
     * job matches nothing and is a no-op.
     *
     * @param request carries the id of the job to delete
     */
    @PostMapping("/remove")
    @Operation(summary = "删除图片")
    public void remove(@RequestBody ImagePublishRequest request) {
        // LambdaUpdateWrapper is used purely as a condition wrapper here (it is already imported).
        sdJobService.remove(new LambdaUpdateWrapper<SdJobEntity>()
                .eq(SdJobEntity::getId, request.getId())
                .eq(SdJobEntity::getUserId, getUserId()));
    }

    /**
     * Sets the publish flag on one of the current user's jobs.
     *
     * <p>The update is constrained by both job id and owner, so a request naming another user's
     * job matches nothing and is a no-op.
     *
     * @param request carries the job id and the new publish state ({@code isAction()})
     */
    @PostMapping("/publish")
    @Operation(summary = "发布图片")
    public void publish(@RequestBody ImagePublishRequest request) {
        sdJobService.update(new LambdaUpdateWrapper<SdJobEntity>()
                .set(SdJobEntity::getPublish, request.isAction())
                .eq(SdJobEntity::getId, request.getId())
                .eq(SdJobEntity::getUserId, getUserId()));
    }

    /** Saves a new job record for a Stability request and writes the job id and user id back onto it. */
    private void initSdStabilityJobEntity(SdStabilityImageRequest request) {
        SdJobEntity sdJobEntity = newPendingJob();
        sdJobEntity.setPrompt(request.getPrompt());
        // Serialize before the job id / user id are written onto the request, as the original did.
        sdJobEntity.setParams(PlusJsonUtils.toJsonString(request));
        sdJobService.save(sdJobEntity);
        request.setSdJobId(sdJobEntity.getId());
        request.setUserId(getUserId());
    }

    /** Saves a new job record for a local SD request and writes the job id and user id back onto it. */
    private void initSdJobEntity(SdLocalImageRequest request) {
        SdJobEntity sdJobEntity = newPendingJob();
        sdJobEntity.setTaskId(request.getTaskId());
        sdJobEntity.setPrompt(request.getPrompt());
        // Serialize before the job id / user id are written onto the request, as the original did.
        sdJobEntity.setParams(PlusJsonUtils.toJsonString(request));
        sdJobService.save(sdJobEntity);
        request.setSdJobId(sdJobEntity.getId());
        request.setUserId(getUserId());
    }

    /** Builds an unsaved job owned by the current user with progress 0 and publish off. */
    private SdJobEntity newPendingJob() {
        SdJobEntity sdJobEntity = new SdJobEntity();
        sdJobEntity.setUserId(getUserId());
        sdJobEntity.setProgress(0);
        sdJobEntity.setPublish(false);
        return sdJobEntity;
    }

    /** Maps a page of job entities to API responses; a null or empty page yields an empty list. */
    private static List<SdJobDetailApiResponse> toResponses(Page<SdJobEntity> pageDTO) {
        if (pageDTO == null || CollUtil.isEmpty(pageDTO.getRecords())) {
            return Collections.emptyList();
        }
        return pageDTO.getRecords().stream().map(SdJobDetailApiResponse::build).toList();
    }

}
